/*
Copyright 2020 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package awsup

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
	awsconfig "github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
	"github.com/aws/aws-sdk-go-v2/service/sts"
	smithyhttp "github.com/aws/smithy-go/transport/http"
	"k8s.io/kops/pkg/bootstrap"
)

// AWS bootstrap tokens consist of one of these prefixes, identifying the
// signing scheme, followed by a base64-encoded JSON payload.
const (
	// AWSAuthenticationTokenPrefixV1 marks a token whose payload is the
	// headers of a SigV4-signed STS GetCallerIdentity POST request.
	AWSAuthenticationTokenPrefixV1 = "x-aws-sts "

	// AWSAuthenticationTokenPrefixV2 marks a token whose payload is a
	// presigned STS GetCallerIdentity request (URL, method and signed headers).
	AWSAuthenticationTokenPrefixV2 = "x-aws-sts-v2 "
)

// awsAuthenticator creates bootstrap tokens that prove the caller's AWS
// identity by signing STS GetCallerIdentity requests.
type awsAuthenticator struct {
	// sts is the AWS STS client used to presign V2 requests.
	sts *sts.Client

	// region is the AWS region in which we are running.
	region string

	// credentialsProvider supplies the AWS credentials used to sign V1 requests.
	credentialsProvider aws.CredentialsProvider
}

// Compile-time check that awsAuthenticator implements bootstrap.Authenticator.
var _ bootstrap.Authenticator = (*awsAuthenticator)(nil)

// RegionFromMetadata returns the AWS region of the current machine, as
// reported by the EC2 instance metadata service (IMDS).
func RegionFromMetadata(ctx context.Context) (string, error) {
	awsCfg, err := awsconfig.LoadDefaultConfig(ctx)
	if err != nil {
		return "", fmt.Errorf("failed to load default aws config: %w", err)
	}

	imdsClient := imds.NewFromConfig(awsCfg)
	out, err := imdsClient.GetRegion(ctx, &imds.GetRegionInput{})
	if err != nil {
		return "", fmt.Errorf("failed to get region from ec2 metadata: %w", err)
	}

	return out.Region, nil
}

// NewAWSAuthenticator returns a bootstrap.Authenticator that signs tokens
// using the credentials from the default AWS configuration, with STS
// configured for the given region.
func NewAWSAuthenticator(ctx context.Context, region string) (bootstrap.Authenticator, error) {
	cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region))
	if err != nil {
		return nil, fmt.Errorf("failed to load aws config: %w", err)
	}

	authenticator := &awsAuthenticator{
		sts:                 sts.NewFromConfig(cfg),
		region:              region,
		credentialsProvider: cfg.Credentials,
	}
	return authenticator, nil
}

type (
	// awsV1Token is the payload of a V1 token: the headers of the signed
	// STS request. Its shape matches http.Header.
	awsV1Token map[string][]string

	// awsV2Token is the payload of a V2 token: the request produced by
	// presigning STS GetCallerIdentity, serialized as JSON.
	awsV2Token struct {
		URL          string      `json:"url"`
		Method       string      `json:"method"`
		SignedHeader http.Header `json:"headers"`
	}
)

// CreateToken returns a token authenticating body. Tokens are currently
// always produced with the V1 signing scheme; see the comment below.
func (a *awsAuthenticator) CreateToken(body []byte) (string, error) {
	// CreateToken takes no context, so use a placeholder.
	ctx := context.TODO()

	// We sign with V1 for backwards compatibility: if nodes are upgraded
	// before the control plane, V2-signing nodes would be talking to a
	// V1-only verifier. The server supports both V1 and V2, so nodes keep
	// using V1 for now and V2 can be enabled a few versions later.
	// Upgrading nodes before the control plane is not the common case, and
	// nodes at much higher versions are not guaranteed to be supported by
	// kube, so once we are at kOps 1.32 this should be safe to flip to V2.
	// It's possibly safe at kOps 1.31, but that needs more careful analysis.
	signWithV2 := false
	if signWithV2 {
		return a.createTokenV2(ctx, body)
	}
	return a.createTokenV1(ctx, body)
}

// createTokenV1 builds a V1 token. It signs an STS GetCallerIdentity POST
// request that carries a hash of body, JSON-encodes the signed request's
// headers, and base64-encodes them after the V1 prefix.
func (a *awsAuthenticator) createTokenV1(ctx context.Context, body []byte) (string, error) {
	creds, err := a.credentialsProvider.Retrieve(ctx)
	if err != nil {
		return "", fmt.Errorf("getting AWS credentials: %w", err)
	}

	stsHost, err := a.getSTSHost(ctx)
	if err != nil {
		return "", fmt.Errorf("getting AWS STS url: %w", err)
	}

	signed, err := signV1Request(ctx, "https://"+stsHost+"/", a.region, creds, time.Now(), body)
	if err != nil {
		return "", fmt.Errorf("building (v1) signed request: %w", err)
	}

	headerJSON, err := json.Marshal(signed.Header)
	if err != nil {
		return "", fmt.Errorf("converting headers to json: %w", err)
	}

	encoded := base64.StdEncoding.EncodeToString(headerJSON)
	return AWSAuthenticationTokenPrefixV1 + encoded, nil
}

// getSTSHost returns the host name of the STS endpoint used by a.sts.
//
// It derives the host by presigning a GetCallerIdentity request with the
// client and parsing the resulting URL. This is inefficient, but it reuses
// the SDK's own endpoint resolution instead of duplicating it here.
//
// An error is returned if presigning fails, the URL cannot be parsed, or the
// URL has no host.
func (a *awsAuthenticator) getSTSHost(ctx context.Context) (string, error) {
	presignClient := sts.NewPresignClient(a.sts)
	stsRequest, err := presignClient.PresignGetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
	if err != nil {
		return "", fmt.Errorf("building AWS STS presigned request: %w", err)
	}

	u, err := url.Parse(stsRequest.URL)
	if err != nil {
		return "", fmt.Errorf("parsing AWS STS url: %w", err)
	}
	if u.Host == "" {
		// Without a host the caller would build "https:///" and sign a
		// request for a meaningless endpoint; fail loudly instead.
		return "", fmt.Errorf("AWS STS url %q has no host", stsRequest.URL)
	}

	return u.Host, nil
}

// createTokenV2 builds a V2 token. It presigns an STS GetCallerIdentity
// request that includes an X-Kops-Request-SHA header holding the SHA-256 of
// body, so the signature is only valid for this particular body content.
// The presigned request's URL, method and signed headers are JSON-encoded
// and base64-encoded after the V2 prefix.
func (a *awsAuthenticator) createTokenV2(ctx context.Context, body []byte) (string, error) {
	bodyHash := sha256.Sum256(body)
	encodedHash := base64.RawStdEncoding.EncodeToString(bodyHash[:])

	// Ensure the signature is only valid for this particular body content.
	addHashHeader := func(o *sts.Options) {
		o.APIOptions = append(o.APIOptions, smithyhttp.AddHeaderValue("X-Kops-Request-SHA", encodedHash))
	}

	presigner := sts.NewPresignClient(a.sts)
	presigned, err := presigner.PresignGetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}, func(po *sts.PresignOptions) {
		po.ClientOptions = append(po.ClientOptions, addHashHeader)
	})
	if err != nil {
		return "", fmt.Errorf("building AWS STS presigned request: %w", err)
	}

	payload, err := json.Marshal(&awsV2Token{
		URL:          presigned.URL,
		Method:       presigned.Method,
		SignedHeader: presigned.SignedHeader,
	})
	if err != nil {
		return "", fmt.Errorf("converting token to json: %w", err)
	}

	return AWSAuthenticationTokenPrefixV2 + base64.StdEncoding.EncodeToString(payload), nil
}

// signV1Request builds and SigV4-signs the well-known STS GetCallerIdentity
// POST request used by V1 tokens. The request is only signed here, not sent.
//
// The request carries an X-Kops-Request-Sha header holding the unpadded
// base64 SHA-256 of kopsRequestBody. Because that header is signed, the
// resulting signature is tied to that body.
//
// Parameters:
//   - stsURL is the STS endpoint the request targets.
//   - region, together with the "sts" service name, forms the signing scope.
//   - credentials are the AWS credentials used to sign.
//   - signingTime is the signature timestamp.
//
// Errors from building or signing the request are wrapped with %w.
func signV1Request(ctx context.Context, stsURL string, region string, credentials aws.Credentials, signingTime time.Time, kopsRequestBody []byte) (*http.Request, error) {
	kopsRequestHash := sha256.Sum256(kopsRequestBody)
	kopsRequestHashBase64 := base64.RawStdEncoding.EncodeToString(kopsRequestHash[:])

	// V1 requests use a well-known body (and host).
	body := []byte("Action=GetCallerIdentity&Version=2011-06-15")
	bodyHash := sha256.Sum256(body)

	signedRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, stsURL, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("building http request: %w", err)
	}
	signedRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
	signedRequest.Header.Add("X-Kops-Request-Sha", kopsRequestHashBase64)

	const service = "sts"

	// SignHTTP takes the payload hash as a hex-encoded SHA-256 digest.
	signer := v4.NewSigner()
	if err := signer.SignHTTP(ctx, credentials, signedRequest, hex.EncodeToString(bodyHash[:]), service, region, signingTime); err != nil {
		return nil, fmt.Errorf("error from SignHTTP: %w", err)
	}

	return signedRequest, nil
}
